{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "213d538c",
   "metadata": {},
   "source": [
    "# T3. dataloader 的内部结构和基本使用\n",
    "\n",
    "&emsp; 1 &ensp; fastNLP 中的 dataloader\n",
    " \n",
    "&emsp; &emsp; 1.1 &ensp; dataloader 的基本介绍\n",
    "\n",
    "&emsp; &emsp; 1.2 &ensp; dataloader 的函数创建\n",
    "\n",
    "&emsp; 2 &ensp; fastNLP 中 dataloader 的延伸\n",
    "\n",
    "&emsp; &emsp; 2.1 &ensp; collator 的概念与使用\n",
    "\n",
    "&emsp; &emsp; 2.2 &ensp; 结合 datasets 框架"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85857115",
   "metadata": {},
   "source": [
    "## 1. fastNLP 中的 dataloader\n",
    "\n",
    "### 1.1 dataloader 的基本介绍\n",
    "\n",
    "在`fastNLP 1.0`的开发中，最关键的开发目标就是**实现 fastNLP 对当前主流机器学习框架**，例如\n",
    "\n",
    "&emsp; **当下流行的 pytorch**，以及**国产的 paddle 、jittor 和 oneflow 的兼容**，扩大受众的同时，也是助力国产\n",
    "\n",
    "本着分而治之的思想，我们可以将`fastNLP 1.0`对`pytorch`、`paddle`、`jittor`、`oneflow`框架的兼容，划分为\n",
    "\n",
    "&emsp; &emsp; **对数据预处理**、**批量 batch 的划分与补齐**、**模型训练**、**模型评测**，**四个部分的兼容**\n",
    "\n",
    "&emsp; 针对数据预处理，我们已经在`tutorial-1`中介绍了`dataset`和`vocabulary`的使用\n",
    "\n",
    "&emsp; &emsp; 而结合`tutorial-0`，我们可以发现**数据预处理环节本质上是框架无关的**\n",
    "\n",
    "&emsp; &emsp; 因为在不同框架下，读取的原始数据格式都差异不大，彼此也很容易转换\n",
    "\n",
    "只有涉及到张量、模型，不同框架才展现出其各自的特色：**pytorch 和 oneflow 中的 tensor 和 nn.Module**\n",
    "\n",
    "&emsp; &emsp; **在 paddle 中称为 tensor 和 nn.Layer**，**在 jittor 中则称为 Var 和 Module**\n",
    "\n",
    "&emsp; &emsp; 因此，**模型训练、模型评测**，**是兼容的重难点**，我们将会在`tutorial-5`中详细介绍\n",
    "\n",
    "&emsp; 针对批量`batch`的处理，作为`fastNLP 1.0`中框架无关部分想框架相关部分的过渡\n",
    "\n",
    "&emsp; &emsp; 就是`dataloader`模块的职责，这也是本篇教程`tutorial-3`讲解的重点\n",
    "\n",
    "**dataloader 模块的职责**，详细划分可以包含以下三部分，**采样划分、补零对齐、框架匹配**\n",
    "\n",
    "&emsp; &emsp; 第一，确定`batch`大小，确定采样方式，划分后通过迭代器即可得到`batch`序列\n",
    "\n",
    "&emsp; &emsp; 第二，对于序列处理，这也是`fastNLP`主要针对的，将同个`batch`内的数据对齐\n",
    "\n",
    "&emsp; &emsp; 第三，**batch 内数据格式要匹配框架**，**但 batch 结构需保持一致**，**参数匹配机制**\n",
    "\n",
    "&emsp; 对此，`fastNLP 1.0`给出了 **TorchDataLoader 、 PaddleDataLoader 、 JittorDataLoader 和 OneflowDataLoader**\n",
    "\n",
    "&emsp; &emsp; 分别针对并匹配不同框架，但彼此之间参数名、属性、方法仍然类似，前两者大致如下表所示\n",
    "\n",
    "名称|参数|属性|功能|内容\n",
    "----|----|----|----|----|\n",
    " `dataset` | √ | √ | 指定`dataloader`的数据内容  |  |\n",
    " `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n",
    " `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n",
    " `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n",
    " `sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n",
    " `batch_sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n",
    " `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n",
    " `cur_batch_indices` |  | √ | 记录`dataloader`当前遍历批量序号 |  |\n",
    " `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n",
    " `worker_init_fn` | √ | √ | 指定`dataloader`子进程初始方法 | 默认`None` |\n",
    " `generator` | √ | √ | 指定`dataloader`子进程随机种子 | 默认`None` |\n",
    " `prefetch_factor` |  | √ | 指定为每个`worker`装载的`sampler`数量 | 默认`2` |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60a8a224",
   "metadata": {},
   "source": [
    "&emsp; 论及`dataloader`的函数，其中，`get_batch_indices`用来获取当前遍历到的`batch`序号，其他函数\n",
    "\n",
    "&emsp; &emsp; 包括`set_ignore`、`set_pad`和`databundle`类似，请参考`tutorial-2`，此处不做更多介绍\n",
    "\n",
    "&emsp; &emsp; 以下是`tutorial-2`中已经介绍过的数据预处理流程，接下来是对相关数据进行`dataloader`处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aca72b49",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[38;5;2m[i 0604 15:44:29.773860 92 log.cc:351] Load log_sync: 1\u001b[m\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Processing:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n",
      "| SentenceId | Sentence       | Sentiment | input_ids      | token_type_ids     | attention_mask     | target |\n",
      "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n",
      "| 1          | A series of... | negative  | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1      |\n",
      "| 4          | A positivel... | neutral   | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 2      |\n",
      "| 3          | Even fans o... | negative  | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1      |\n",
      "| 5          | A comedy-dr... | positive  | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 0      |\n",
      "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import pandas as pd\n",
    "from functools import partial\n",
    "from fastNLP.transformers.torch import BertTokenizer\n",
    "\n",
    "from fastNLP import DataSet\n",
    "from fastNLP import Vocabulary\n",
    "from fastNLP.io import DataBundle\n",
    "\n",
    "\n",
    "class PipeDemo:\n",
    "    def __init__(self, tokenizer='bert-base-uncased'):\n",
    "        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n",
    "\n",
    "    def process_from_file(self, path='./data/test4dataset.tsv'):\n",
    "        datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n",
    "        train_ds, test_ds = datasets.split(ratio=0.7)\n",
    "        train_ds, dev_ds = datasets.split(ratio=0.8)\n",
    "        data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n",
    "\n",
    "        encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n",
    "                         return_attention_mask=True)\n",
    "        data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n",
    "        \n",
    "        target_vocab = Vocabulary(padding=None, unknown=None)\n",
    "\n",
    "        target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n",
    "        target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n",
    "                                   new_field_name='target')\n",
    "\n",
    "        data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n",
    "        data_bundle.set_ignore('SentenceId', 'Sentence', 'Sentiment')  \n",
    "        return data_bundle\n",
    "\n",
    "    \n",
    "pipe = PipeDemo(tokenizer='bert-base-uncased')\n",
    "\n",
    "data_bundle = pipe.process_from_file('./data/test4dataset.tsv')\n",
    "\n",
    "print(data_bundle.get_dataset('train'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76e6b8ab",
   "metadata": {},
   "source": [
    "### 1.2 dataloader 的函数创建\n",
    "\n",
    "在`fastNLP 1.0`中，**更方便、可能更常用的 dataloader 创建方法是通过 prepare_xx_dataloader 函数**\n",
    "\n",
    "&emsp; 例如下方的`prepare_torch_dataloader`函数，指定必要参数，读取数据集，生成对应`dataloader`\n",
    "\n",
    "&emsp; 类型为`TorchDataLoader`，只能适用于`pytorch`框架，因此对应`trainer`初始化时`driver='torch'`\n",
    "\n",
    "同时我们看还可以发现，在`fastNLP 1.0`中，**batch 表示为字典 dict 类型**，**key 值就是原先数据集中各个字段**\n",
    "\n",
    "&emsp; **除去经过 DataBundle.set_ignore 函数隐去的部分**，而`value`值为`pytorch`框架对应的`torch.Tensor`类型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5fd60e42",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'fastNLP.core.dataloaders.torch_dataloader.fdl.TorchDataLoader'>\n",
      "<class 'dict'> <class 'torch.Tensor'> ['input_ids', 'token_type_ids', 'attention_mask', 'target']\n",
      "{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),\n",
      " 'input_ids': tensor([[  101,  1037,  4038,  1011,  3689,  1997,  3053,  8680, 19173, 15685,\n",
      "          1999,  1037, 18006,  2836,  2011,  1996,  2516,  2839, 14996,  3054,\n",
      "         15509,  5325,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0],\n",
      "        [  101,  1037,  2186,  1997,  9686, 17695, 18673, 14313,  1996, 15262,\n",
      "          3351,  2008,  2054,  2003,  2204,  2005,  1996, 13020,  2003,  2036,\n",
      "          2204,  2005,  1996, 25957,  4063,  1010,  2070,  1997,  2029,  5681,\n",
      "          2572, 25581,  2021,  3904,  1997,  2029,  8310,  2000,  2172,  1997,\n",
      "          1037,  2466,  1012,   102],\n",
      "        [  101,  2130,  4599,  1997, 19214,  6432,  1005,  1055,  2147,  1010,\n",
      "          1045,  8343,  1010,  2052,  2031,  1037,  2524,  2051,  3564,  2083,\n",
      "          2023,  2028,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0],\n",
      "        [  101,  1037, 13567, 26162,  5257,  1997,  3802,  7295,  9888,  1998,\n",
      "          2035,  1996, 20014, 27611,  1010, 14583,  1010, 11703, 20175,  1998,\n",
      "          4028,  1997,  1037,  8101,  2319, 10576,  2030,  1037, 28900,  7815,\n",
      "          3850,  1012,   102,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0]]),\n",
      " 'target': tensor([0, 1, 1, 2]),\n",
      " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n"
     ]
    }
   ],
   "source": [
    "from fastNLP import prepare_torch_dataloader\n",
    "\n",
    "train_dataset = data_bundle.get_dataset('train')\n",
    "evaluate_dataset = data_bundle.get_dataset('dev')\n",
    "\n",
    "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n",
    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)\n",
    "\n",
    "print(type(train_dataloader))\n",
    "\n",
    "import pprint\n",
    "\n",
    "for batch in train_dataloader:\n",
    "    print(type(batch), type(batch['input_ids']), list(batch))\n",
    "    pprint.pprint(batch, width=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f457a6e",
   "metadata": {},
   "source": [
    "之所以说`prepare_xx_dataloader`函数更方便，是因为其**导入对象不仅可也是 DataSet 类型**，**还可以**\n",
    "\n",
    "&emsp; **是 DataBundle 类型**，不过数据集名称需要是`'train'`、`'dev'`、`'test'`供`fastNLP`识别\n",
    "\n",
    "例如下方就是**直接通过 prepare_paddle_dataloader 函数生成基于 PaddleDataLoader 的字典**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7827557d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'fastNLP.core.dataloaders.paddle_dataloader.fdl.PaddleDataLoader'>\n"
     ]
    }
   ],
   "source": [
    "from fastNLP import prepare_paddle_dataloader\n",
    "\n",
    "dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)\n",
    "\n",
    "print(type(dl_bundle['train']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d898cf40",
   "metadata": {},
   "source": [
    "&emsp; 而在接下来`trainer`的初始化过程中，按如下方式使用即可，除了初始化时`driver='paddle'`外\n",
    "\n",
    "&emsp; 这里也可以看出`trainer`模块中，**evaluate_dataloaders 的设计允许评测可以针对多个数据集**\n",
    "\n",
    "```python\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataloader=dl_bundle['train'],\n",
    "    optimizers=optimizer,\n",
    "\t...\n",
    "\tdriver='paddle',\n",
    "\tdevice='gpu',\n",
    "\t...\n",
    "    evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']},     \n",
    "    metrics={'acc': Accuracy()},\n",
    "\t...\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d74d0523",
   "metadata": {},
   "source": [
    "## 2. fastNLP 中 dataloader 的延伸\n",
    "\n",
    "### 2.1 collator 的概念与使用\n",
    "\n",
    "在`fastNLP 1.0`中，在数据加载模块`dataloader`内部，如之前表格所列举的，还存在其他的一些模块\n",
    "\n",
    "&emsp; 例如，**实现序列的补零对齐的核对器 collator 模块**；注：`collate vt. 整理（文件或书等）；核对，校勘`\n",
    "\n",
    "在`fastNLP 1.0`中，虽然`dataloader`随框架不同，但`collator`模块却是统一的，主要属性、方法如下表所示\n",
    "\n",
    "名称|属性|方法|功能|内容\n",
    " ----|----|----|----|----|\n",
    " `backend` | √ |  | 记录`collator`对应框架 | 字符串型，如`'torch'` |\n",
    " `padders` | √ |  | 记录各字段对应的`padder`，每个负责具体补零对齐&emsp; | 字典类型 |\n",
    " `ignore_fields` | √ |  | 记录`dataloader`采样`batch`时不予考虑的字段 | 集合类型 |\n",
    " `input_fields` | √ |  | 记录`collator`每个字段的补零值、数据类型等 | 字典类型 |\n",
    " `set_backend` |  | √ | 设置`collator`对应框架 | 字符串型，如`'torch'` |\n",
    " `set_ignore` |  | √ | 设置`dataloader`采样`batch`时不予考虑的字段 | 字符串型，表示`field_name`&emsp; |\n",
    " `set_pad` |  | √ | 设置`collator`每个字段的补零值、数据类型等 |  |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d0795b3e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'function'>\n"
     ]
    }
   ],
   "source": [
    "train_dataloader.collate_fn\n",
    "\n",
    "print(type(train_dataloader.collate_fn))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f816ef5",
   "metadata": {},
   "source": [
    "此外，还可以 **手动定义 dataloader 中的 collate_fn**，而不是使用`fastNLP 1.0`中自带的`collator`模块\n",
    "\n",
    "&emsp; 该函数的定义可以大致如下，需要注意的是，**定义 collate_fn 之前需要了解 batch 作为字典的格式**\n",
    "\n",
    "&emsp; 该函数通过`collate_fn`参数传入`dataloader`，**在 batch 分发**（**而不是 batch 划分**）**时调用**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ff8e405e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def collate_fn(batch):\n",
    "    input_ids, atten_mask, labels = [], [], []\n",
    "    max_length = [0] * 3\n",
    "    for each_item in batch:\n",
    "        input_ids.append(each_item['input_ids'])\n",
    "        max_length[0] = max(len(each_item['input_ids']), max_length[0])\n",
    "        atten_mask.append(each_item['token_type_ids'])\n",
    "        max_length[1] = max(len(each_item['token_type_ids']), max_length[1])\n",
    "        labels.append(each_item['attention_mask'])\n",
    "        max_length[2] = max(len(each_item['attention_mask']), max_length[2])\n",
    "\n",
    "    for i in range(3):\n",
    "        each = (input_ids, atten_mask, labels)[i]\n",
    "        for item in each:\n",
    "            item.extend([0] * (max_length[i] - len(item)))\n",
    "    return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n",
    "            'token_type_ids': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n",
    "            'attention_mask': torch.cat([torch.tensor(item) for item in labels], dim=0)}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "487b75fb",
   "metadata": {},
   "source": [
    "注意：使用自定义的`collate_fn`函数，`trainer`的`collate_fn`变量也会自动调整为`function`类型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e916d1ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'fastNLP.core.dataloaders.torch_dataloader.fdl.TorchDataLoader'>\n",
      "<class 'function'>\n",
      "{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n",
      "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n",
      "        0, 0, 0, 0, 0, 0, 0, 0]),\n",
      " 'input_ids': tensor([[  101,  1037,  4038,  1011,  3689,  1997,  3053,  8680, 19173, 15685,\n",
      "          1999,  1037, 18006,  2836,  2011,  1996,  2516,  2839, 14996,  3054,\n",
      "         15509,  5325,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0],\n",
      "        [  101,  1037,  2186,  1997,  9686, 17695, 18673, 14313,  1996, 15262,\n",
      "          3351,  2008,  2054,  2003,  2204,  2005,  1996, 13020,  2003,  2036,\n",
      "          2204,  2005,  1996, 25957,  4063,  1010,  2070,  1997,  2029,  5681,\n",
      "          2572, 25581,  2021,  3904,  1997,  2029,  8310,  2000,  2172,  1997,\n",
      "          1037,  2466,  1012,   102],\n",
      "        [  101,  2130,  4599,  1997, 19214,  6432,  1005,  1055,  2147,  1010,\n",
      "          1045,  8343,  1010,  2052,  2031,  1037,  2524,  2051,  3564,  2083,\n",
      "          2023,  2028,  1012,   102,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0],\n",
      "        [  101,  1037, 13567, 26162,  5257,  1997,  3802,  7295,  9888,  1998,\n",
      "          2035,  1996, 20014, 27611,  1010, 14583,  1010, 11703, 20175,  1998,\n",
      "          4028,  1997,  1037,  8101,  2319, 10576,  2030,  1037, 28900,  7815,\n",
      "          3850,  1012,   102,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0]]),\n",
      " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n"
     ]
    }
   ],
   "source": [
    "train_dataloader = prepare_torch_dataloader(train_dataset, collate_fn=collate_fn, shuffle=True)\n",
    "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, collate_fn=collate_fn, shuffle=True)\n",
    "\n",
    "print(type(train_dataloader))\n",
    "print(type(train_dataloader.collate_fn))\n",
    "\n",
    "for batch in train_dataloader:\n",
    "    pprint.pprint(batch, width=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bd98365",
   "metadata": {},
   "source": [
    "### 2.2  fastNLP 与 datasets 的结合\n",
    "\n",
    "从`tutorial-1`至`tutorial-3`，我们已经完成了对`fastNLP v1.0`数据读取、预处理、加载，整个流程的介绍\n",
    "\n",
    "&emsp; 不过在实际使用中，我们往往也会采取更为简便的方法读取数据，例如使用`huggingface`的`datasets`模块\n",
    "\n",
    "**使用 datasets 模块中的 load_dataset 函数**，通过指定数据集两级的名称，示例中即是**GLUE 标准中的 SST-2 数据集**\n",
    "\n",
    "&emsp; 即可以快速从网上下载好`SST-2`数据集读入，之后以`pandas.DataFrame`作为中介，再转化成`fastNLP.DataSet`\n",
    "\n",
    "&emsp; 之后的步骤就和其他关于`dataset`、`databundle`、`vocabulary`、`dataloader`中介绍的相关使用相同了"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "91879c30",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "639a0ad3c63944c6abef4e8ee1f7bf7c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "sst2data = load_dataset('glue', 'sst2')\n",
    "\n",
    "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
